{- 

    Copyright © 2011 - 2016, Ingo Wechsung
    All rights reserved.

    Redistribution and use in source and binary forms, with or
    without modification, are permitted provided that the following
    conditions are met:

        Redistributions of source code must retain the above copyright
        notice, this list of conditions and the following disclaimer.

        Redistributions in binary form must reproduce the above
        copyright notice, this list of conditions and the following
        disclaimer in the documentation and/or other materials provided
        with the distribution. 
        
        Neither the name of the copyright holder
        nor the names of its contributors may be used to endorse or
        promote products derived from this software without specific
        prior written permission.

    THIS SOFTWARE IS PROVIDED BY THE
    COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
    IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
    WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
    PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER
    OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
    SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
    LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF
    USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
    AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
    LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
    IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
    THE POSSIBILITY OF SUCH DAMAGE.

     -}

{--
    Simple utility functions for the frege compiler
 -}

module frege.compiler.Utilities
    where

import frege.Prelude except (error, print, println, break, <+>)
import Data.TreeMap as TM(TreeMap, lookup, each, insert,
                    union, including, contains, keys, values, fromKeys)
import Data.List as DL(partitioned, sortBy, minimumBy, \\)

import Lib.PP(fill, break, pretty, text, nest, msgdoc, <+>, <>, DOCUMENT)

-- import Compiler.enums.Flags 
import Compiler.enums.TokenID(defaultInfix, VARID)
import Compiler.enums.RFlag(RState)
import Compiler.enums.Visibility
import Compiler.enums.SymState


import  Compiler.types.Positions
import  Compiler.types.Tokens
import  Compiler.types.Strictness
import  Compiler.types.NSNames
import  Compiler.types.SNames
import  Compiler.types.Packs
import  Compiler.types.QNames
import  Compiler.types.Types
import  Compiler.types.Patterns
import  Compiler.types.Symbols
import  Compiler.types.Expression
import  Compiler.types.Global as G
import  Compiler.classes.Nice(Nice)
import  Compiler.common.Errors as E()
import  Compiler.common.Resolve as R
import  Compiler.instances.Nicer 
import  Compiler.common.SymbolTable
import  Compiler.common.Binders
import  Compiler.common.Types as CT

--- Associate the type name @tname@ with the java class name @nativ@.
--- The name is put in front of the types already recorded for @nativ@ (an older
--- occurrence of it is dropped); the recorded supertypes are left untouched.
nativeType nativ tname = do
    g <- getST
    let entry = case g.javaEnv.lookup nativ of
            Just (supers, known) -> (supers, tname : filter (tname!=) known)
            _                    -> ([], [tname])
    changeST Global.{javaEnv <- insert nativ entry}


--- Record the names of @classes@ as the supertypes of the java class name @nativ@.
--- Each name is obtained with @Class.getName@ and has every @$@ replaced by a dot.
--- The types already associated with @nativ@ are kept.
nativeSupers nativ classes = do
    g <- getST
    let !supers = map ((_.replaceAll '\$' ".") . Class.getName) classes
        known = case g.javaEnv.lookup nativ of
            Just (_, types) -> types
            _               -> []
    changeST Global.{javaEnv <- insert nativ (supers, known)}


--- The type names recorded for the native type @nativ@,
--- or the empty list if nothing is known about it.
typesOfNativ :: String -> Global -> [QName]
typesOfNativ nativ g = maybe [] snd (g.javaEnv.lookup nativ)


--- The proper super types recorded for the native type @nativ@
--- (i.e. the type itself is not in the list),
--- or the empty list if nothing is known about it.
supersOfNativ :: String -> Global -> [String]
supersOfNativ nativ g = maybe [] fst (g.javaEnv.lookup nativ)


--- Find the class symbol named @qname@ or die:
--- anything else that is found, or nothing at all, is reported through @E.fatal@.
findC qname = do
    g <- getST
    case g.findit qname of
        Just (cls@SymC{}) -> stio cls
        Just other -> E.fatal other.pos $ fill $ break $
                        "looked for class " ++ qname.nice g ++ ", found " ++ other.nice g
        Nothing -> E.fatal Position.null $ fill $ break $
                        "looked for class " ++ qname.nice g ++ ", found Nothing"


--- Find the instance symbol named @qname@ or die:
--- anything else that is found, or nothing at all, is reported through @E.fatal@.
findI qname = do
    g <- getST
    case g.findit qname of
        Just (inst@SymI{}) -> stio inst
        Just other -> E.fatal other.pos $ fill $ break $
                        "looked for instance " ++ qname.nice g ++ ", found " ++ other.nice g
        Nothing -> E.fatal Position.null $ fill $ break $
                        "looked for instance " ++ qname.nice g ++ ", found Nothing"


--- Find the type symbol named @qname@ or die:
--- anything else that is found, or nothing at all, is reported through @E.fatal@.
findT qname = do
    g <- getST
    case g.findit qname of
        Just (tsym@SymT{}) -> stio tsym
        Just other -> E.fatal other.pos $ fill $ break $
                        "looked for type " ++ qname.nice g ++ ", found " ++ other.nice g
        Nothing -> E.fatal Position.null $ fill $ break $
                        "looked for type " ++ qname.nice g ++ ", found Nothing"


--- Find the function/variable symbol named @qname@ or die:
--- anything else that is found, or nothing at all, is reported through @E.fatal@.
findV qname = do
    g <- getST
    case g.findit qname of
        Just (vsym@SymV{}) -> stio vsym
        Just other -> E.fatal other.pos $ fill $ break $
                        "looked for function " ++ qname.nice g ++ ", found " ++ other.nice g
        Nothing -> E.fatal Position.null $ fill $ break $
                        "looked for function " ++ qname.nice g ++ ", found Nothing"


--- Find the symbol named @qname@, which must be a function/variable ('SymV')
--- or a data constructor ('SymD'), or die:
--- anything else that is found, or nothing at all, is reported through @E.fatal@.
findVD qname = do
    g <- getST
    case g.findit qname of
        Just (symc@SymV {pos}) -> stio symc
        Just (symc@SymD {pos}) -> stio symc
        Just sym -> E.fatal sym.pos (fill (break ("looked for function or constructor " ++ qname.nice g ++ ", found "
                                    ++ sym.nice g)))
        -- the message names both accepted kinds, in agreement with the branch above
        Nothing -> E.fatal Position.null (fill (break ("looked for function or constructor " ++ qname.nice g ++ ", found Nothing")))


--- Find the data constructor symbol named @qname@ or die:
--- anything else that is found, or nothing at all, is reported through @E.fatal@.
findD qname = do
    g <- getST
    case g.findit qname of
        Just (dsym@SymD{}) -> stio dsym
        Just other -> E.fatal other.pos $ fill $ break $
                        "looked for constructor " ++ qname.nice g ++ ", found " ++ other.nice g
        Nothing -> E.fatal Position.null $ fill $ break $
                        "looked for constructor " ++ qname.nice g ++ ", found Nothing"


{--
    Interpret @NS.T@ or @T@ as a type name.

    A simple name is placed in the current package, a name qualified once is
    placed in the package its namespace stands for. The result is 'Nothing'
    if the namespace is unknown or the name has any other form.
 -}
nstname :: SName -> Global -> Maybe QName
nstname (Simple t)   g = Just (TName g.thisPack t.value)
nstname (With1 ns t) g = fmap (\p -> TName p t.value) (g.namespaces.lookup (NSX ns.value))
nstname _            _ = Nothing


--- The 'TVar's that occur unbound in a rho type.
--- The first argument lists the names of the 'TVar's bound outside this rho.
freeTVars :: [String] -> RhoT a -> [TauT a]
freeTVars bound = values . freeRhoTVars bound TreeMap.empty


--- Like 'freeTVars', but yields only the names of the unbound 'TVar's.
freeTVnames :: [String] -> RhoT a -> [String]
freeTVnames bound = keys . freeRhoTVars bound TreeMap.empty


--- @freeRhoTVars bound collected rho@
--- adds to @collected@ the 'TVar's of @rho@ (contexts included) whose names are not in @bound@.
--- For a function type, the variables quantified by the argument sigma count as
--- bound while the argument type is searched.
freeRhoTVars bound collected (RhoFun ctx (ForAll bndleft rho1) rho2) =
    freeRhoTVars argBound fromResult rho1
  where
    fromCtx    = freeCtxTVars bound collected ctx
    fromResult = freeRhoTVars bound fromCtx rho2
    argBound   = map _.var bndleft ++ bound
freeRhoTVars bound collected (RhoTau ctx tau) =
    freeTauTVars bound (freeCtxTVars bound collected ctx) tau


--- @freeCtxTVars bnd coll ctxs@
--- adds the 'TVar's that occur in the contexts @ctxs@ and are not an element of @bnd@
--- to the 'TreeMap' @coll@
freeCtxTVars bnd coll cs = fold (\acc ctx -> freeTauTVars bnd acc (Context.tau ctx)) coll cs


--- @freeTauTVars bound collected tau@
--- adds to @collected@, keyed by name, the 'TVar's of @tau@ whose names are not in @bound@.
--- Types mentioned in a 'KGen' kind of a type variable are searched as well,
--- with the variable itself counted as bound.
freeTauTVars :: [String] -> TreeMap String (TauT β) -> TauT β -> TreeMap String (TauT β)
freeTauTVars bound collected (tv@TVar{var,kind})
    | isJust tv.wildTau || (var `elem` bound) = fromKind
    | otherwise                               = insert var tv fromKind
    where
        -- what the kind of the variable contributes, whether or not it is free itself
        fromKind = case kind of
            KGen ts -> fold (freeTauTVars (var:bound)) collected ts
            _       -> collected
freeTauTVars bound collected (TApp a b) = freeTauTVars bound fromA b
    where fromA = freeTauTVars bound collected a
freeTauTVars bound collected (TSig sig) = freeRhoTVars (sig.vars ++ bound) collected sig.rho
freeTauTVars _ collected _ = collected


--- The symbols of this environment that have a @cid@ field (the constructors),
--- ordered by constructor number.
envConstructors :: Symtab -> [Symbol]
envConstructors env = sortBy (comparing Symbol.cid) (filter Symbol.{cid?} (values env))


--- Provide a new 'Position' for a pattern and all of its sub-patterns.
--- Variables get the new position with the token changed to a @VARID@ of the variable name.
--- A pattern that has neither a @pat@ nor a @pos@ field is reported as an error and returned unchanged.
pPos newp (p@PCon{}) = do
    subs <- mapSt (pPos newp) p.pats
    return p.{pats=subs, pos=newp}
pPos newp (p@PConFS{}) = do
    let moved (nm, sub) = do
            npat <- pPos newp sub
            return (nm, npat)
    fields <- mapSt moved p.fields
    return p.{fields, pos=newp}
pPos newp (p@PAt{pos, var, pat}) = do
    sub <- pPos newp pat
    return p.{pat=sub, pos = newp.change VARID var}
pPos newp (p@PVar{pos, var}) = return p.{pos = newp.change VARID var}
pPos newp p
    | p.{pat?} = do
        sub <- pPos newp p.pat
        return p.{pat=sub}
    | p.{pos?} = return p.{pos=newp}
    | otherwise = do
        g <- getST
        E.error newp (text ("pPos: strange pattern: " ++ nice p g))
        return p


--- Provide new uids for a pattern:
--- every 'PVar', 'PMat' and 'PAt' in it gets a fresh uid from @uniqid@,
--- the rest of the pattern is rebuilt as it was.
pReturn (p@PCon{}) = do
    subs <- mapSt pReturn p.pats
    return p.{pats=subs}
pReturn (p@PConFS{}) = do
    let refresh (nm, sub) = do
            npat <- pReturn sub
            return (nm, npat)
    fields <- mapSt refresh p.fields
    return p.{fields}
pReturn (p@PUser{pat}) = p.{pat=} <$> pReturn pat
pReturn (p@PAnn{pat})  = p.{pat=} <$> pReturn pat
pReturn (p@PLit{}) = return p
pReturn (p@PVar{}) = p.{uid=} <$> uniqid
pReturn (p@PMat{}) = p.{uid=} <$> uniqid
pReturn (p@PAt{}) = do
    -- the uid of the alias is drawn before those of the sub-pattern
    uid <- uniqid
    sub <- pReturn p.pat
    return p.{pat=sub, uid}


--- The 'Local' 'QName' made from the uid and the variable name of a pattern.
--- The pattern must have the fields 'Pattern.uid' and 'Pattern.var'.
pVarLocal :: PatternT a -> QName
pVarLocal p = Local{uid = p.uid, base = p.var}


{--
    @patLocal pos uid name@ creates, by means of 'vSym', a variable symbol
    at @pos@ for the local name made from @uid@ and @name@.
 -}
patLocal pos uid name = vSym pos local
    where local = Local uid name


--- Set the uid of a local symbol: both the @sid@ of the symbol
--- and the @uid@ of its name are replaced by @uid@.
setuid uid = Symbol.{sid = uid, name <- QName.{uid = uid}}


{--
    Make a new local symbol from a 'PVar' and enter it in the symbol table.

    The symbol is built with 'patLocal' from the position, the absolute value
    of the uid and the variable name of the pattern, hence the pattern must
    provide the fields @pos@, @uid@ and @var@.
    -}
mkLocal :: Pattern -> StG Symbol
mkLocal pvar = enter local >> stio local
    where local = patLocal pvar.pos (abs pvar.uid) pvar.var

--- Make a completely fresh variable: a 'PVar' at @pos@ with a new uid.
--- The matching local symbol is entered in the symbol table, the pattern is returned.
freshVar pos = do
    fresh <- uniqid
    let pvar = PVar{pos, uid = fresh, var = "µ$"}
    mkLocal pvar >> pure pvar

--- Update the name of a local variable whose uid is still 0: it is replaced by the
--- name of the first symbol in @syms@ that has the same base name, if there is one.
--- Such variables come back as 'Right' (replaced or not), all other expressions as 'Left'.
replaceLocals :: [Symbol] -> Expr -> StG (Either Expr Expr)
replaceLocals syms (v@Vbl {name = Local 0 s}) = stio (Right renamed)
    where
        renamed = case DL.find (\sym -> sym.name.base == s) syms of
            Just sym -> v.{name = sym.name}
            _        -> v
replaceLocals _ x = stio (Left x)


{--
    @vSym pos name@ creates a 'SymV' symbol where all fields except 'Symbol.pos'
    and 'Symbol.name' are set to standard values.

    If the name has a @uid@ field, that uid becomes the 'Symbol.sid',
    otherwise the 'Symbol.sid' is 0.
 -}
vSym pos name = SymV {pos, sid, name,
                    vis = Private, doc = Nothing,
                    typ = pSigma, expr = Nothing, nativ = Nothing,
                    pur = false, anno = false, exported = false,
                    state = Unchecked,
                    strsig = Strictness.U, depth = 0,
                    rkind = RState.empty,
                    throwing = [], over = [], gargs = [],
                    op = defaultInfix }
    where
        sid = if QName.{uid?} name then name.uid else 0


--- Check how many lambdas are directly nested inside each other,
--- starting with the given expression. The result is 0 if it is not a lambda.
lambdaDepth x = case x of
    Lam {ex} -> 1 + lambdaDepth ex
    _        -> 0


{--
    Make a rho type from a tau type.

    When @Tau.getFun@ splits the tau into argument and result, the outcome is a
    'RhoFun' of the converted parts, otherwise it is a plain 'RhoTau'.
    No contexts and no bound variables are introduced.
 -}
rhoTau tau = case Tau.getFun tau of
    Just (arg, res) -> RhoFun [] (ForAll [] (rhoTau arg)) (rhoTau res)
    _               -> RhoTau [] tau


{--
    validSigma - make sure the user provided 'Sigma' is valid:

    1) if there are no variables in forall, look up all free variables and provide them

    2) if there are variables in forall, make sure there are no free variables in the type

    3) make sure the variables in forall actually do appear in the type

    Cases 2 and 3 are handled by 'validSigma1'.
 -}
validSigma :: SigmaT a -> StG (SigmaT a)
validSigma sigma = case sigma of
    ForAll [] rho -> do
        let free = freeTVars [] rho
        checked <- validRho (map Tau.var free) rho
        stio (ForAll free checked)
    _ -> validSigma1 [] sigma

--- An environment of type variables: maps a name to a 'TauT'.
type SigmaEnv a = TreeMap String (TauT a)

{--
    @validSigma1 outer (ForAll bound rho)@ checks a sigma that lives in the
    scope of the type variable names @outer@.

    An error is flagged when

    - the rho mentions type variables that are neither bound here nor in @outer@

    - one of the variables bound here does not appear in the rho

    The rho itself is then checked with 'validRho'; the bound variables of the
    result are those of the argument.
 -}
validSigma1 outer (ForAll bound rho) = do
    let inner = map _.var bound
        -- free in rho even when everything bound here or outside is taken into account
        vs = freeTVnames (DL.union inner outer) rho
        -- free in rho when only outer names that are not rebound here count as bound
        fs = freeTVnames (outer \\ inner) rho
        -- bound here, but not among the names in fs
        ov = [ v | TVar{var=v} <- bound, v `notElem` fs ]
    unless (null vs)
        (E.error (rhoPos rho) (fill (break ("free type variable(s) " ++ joined ", " vs ++ " not allowed"))))
    unless (null ov)
        (E.error (rhoPos rho) (fill (break ("bound type variable(s) " ++ joined ", " ov
            ++ " not appearing in type"))))
    rho <- validRho (vs DL.`union` (inner DL.`union` outer)) rho
    stio (ForAll bound rho)


{--
    Check that contexts only contain type variables referenced in rho.
    Since well formed types contain no free type variables,
    these are exactly the type variables bound in the outer *forall*s,
    whose names are passed in @bound@.

    For a function type, the argument sigma and the result rho are checked
    as well, and the rho is rebuilt from the checked parts.
 -}
-- validRho :: [String] -> RhoT Ord:a -> StG (RhoT a)
validRho bound rho = do
        foreach (Rho.context rho) check
        case rho of
            RhoFun _ sig r -> do
                 nsig <- validSigma1 bound sig
                 nrho <- validRho    bound r
                 stio rho.{sigma=nsig, rho=nrho}
            _ -> stio rho
    where
        -- flag a context that names type variables which are not bound
        check (Ctx pos _ tau) = do
            let unbound = [ v | v <- keys (freeTauTVars [] empty tau), v `notElem` bound ]
            unless (null unbound) do
                E.error pos (fill (break ("context must not name type variables " ++ joined "," unbound)))
            stio ()


{--
    Check a sigma with 'validSigma' and translate the result to 'QName' form.
    The translation starts with an empty environment of outer type variables.
 -}
transSigma sigma = do
    valid <- validSigma sigma
    transSigma1 empty valid


{--
    Check a sigma with outer bound vars.

    > transSigma1 outer (ForAll bound rho) 

    Pre: All type variables free in rho appear either in @outer@ or in @bound@

    The bounds and the rho are translated to 'QName' form. The bound variables
    of the result are the type variables found free in the translated rho, and
    all contexts found along the chain of result types are gathered in the
    context of the outermost rho.
    
-}
private transSigma1 ∷ SigmaEnv QName → SigmaS → StG Sigma
private transSigma1 outer sigma = do
    -- sigma <- validSigma1 outer sigma
    case sigma of 
        ForAll bound rho -> do
                -- the type variables bound by this sigma, with translated kinds
                inner ←  transBounds sigma.bound
                -- env: outer plus the variables bound here, entered under their names
                -- unshadowedOuter: outer without the names that are bound again here
                let env = fold (\tm tv → tm.insertS tv.var tv) outer inner
                    unshadowedOuter = fold (\tm tv → tm.delete tv.var) outer inner
                rho <- transRho env rho
                -- freeTV: variables of the translated rho not bound by an unshadowed outer name
                -- newbound: for each of them the translated bound of that name, if there is
                --           one, otherwise the variable as found
                let freeTV = freeTVars (keys unshadowedOuter) rho   -- same as inner?
                    innerMap = fold (\tm tv → tm.insertS tv.var tv) TreeMap.empty inner
                    newbound = [ maybe tv id (innerMap.lookupS tv.var) | tv ← freeTV ]
                    -- collect the contexts of a rho and of the result rhos below it,
                    -- and return them together with the rho stripped of those contexts
                    constraints rho = case rho of 
                        RhoTau{} -> (rho.context, rho.{context=[]})
                        RhoFun{} -> (rho.context ++ subctx, rho.{context=[], rho=subrho})
                                    where (subctx, subrho) = constraints rho.rho
                    (ctx, rrho) = constraints rho
                    -- all collected contexts go to the outermost rho
                    nrho = rrho.{context=ctx} 
                -- NOTE(review): unTau is not defined in this part of the file; presumably it
                -- normalizes function types, confirm against its definition
                case rho of
                    RhoTau _ _ -> return (ForAll newbound nrho)
                    _          -> return (ForAll newbound (unTau nrho))


{--
    Transform a string context to a 'QName' context.

    The class name is resolved through @defaultXName@ (called with the @Eq@ name
    from @pPreludeBase@), the type is translated with 'transTau' and must then
    pass 'forceTau'.
 -}
transCtx (Ctx pos name tau) = do
        cname  <- defaultXName pos (TName pPreludeBase "Eq") name
        sigma  <- transTau tau
        forced <- forceTau sigma
        stio (Ctx pos cname forced)


{--
    Check a rho and translate it to 'QName' form.

    For a function type the argument sigma, the result rho and the contexts are
    translated in this order. For a plain tau, the translated contexts are
    appended to the contexts of the rho that 'transTau' produced for the tau.
 -}
transRho :: SigmaEnv QName -> RhoS -> StG Rho
transRho outer (RhoFun ctx sig rho) = do
        nsig <- transSigma1 outer sig
        nrho <- transRho outer rho
        nctx <- mapSt transCtx ctx
        stio (RhoFun nctx nsig nrho)
transRho outer (RhoTau ctx tau) = do
        sig  <- transTau tau
        nctx <- mapSt transCtx ctx
        stio $ sig.rho.{context <- (++ nctx)}

--- Translate a 'KindT' 'SName' to a 'KindT' 'QName'.
--- The types listed in a 'KGen' kind are translated with 'transTau'. If the result
--- of that is not a plain tau, an error is flagged and a type variable named
--- @wrong@ takes its place.
transKind KType = pure KType
transKind KVar  = pure KVar
transKind (KApp a b) = do
        ka <- transKind a
        kb <- transKind b
        pure (KApp ka kb)
transKind (KGen ts) = KGen <$> mapM trans1 ts
    where
        trans1 src = do
            sigma <- transTau src
            case sigmaAsTau sigma of
                Just tau -> pure tau
                Nothing  -> do
                    E.error (getpos src) (text "invalid type in a kind")
                    pure TVar{pos=getpos src, kind=KType, var="wrong"}

--- Translate a 'Sigma' bound, that is, translate the kind of a type variable
--- and leave everything else as it is.
--- Will die of pattern match failure if the argument doesn't have a @kind@ field.
transBound ∷ TauS → StG Tau
transBound tv = do
    k <- transKind tv.kind
    pure tv.{kind=k}

--- Translate a list of 'Sigma' bounds one after another with 'transBound'.
transBounds ∷ [TauS] → StG [Tau]
transBounds tvs = mapM transBound tvs

{--
    Check a tau and translate it to 'QName' form.

    The result is a sigma, because a type can turn out to be a *forall* type.
    Type constructors and type applications are handled by 'transTApp'.
    A 'Meta' type is a fatal error at this stage.
 -}
transTau :: TauT SName -> StG (SigmaT QName)
transTau (TVar{pos,var,kind}) = do
    nkind <- transKind kind
    pure (tauAsSigma TVar{pos, var, kind=nkind})
transTau (t@Meta _) = do    -- Meta must not happen at this time
    g <- getST
    E.fatal (tauPos t) (text("meta type must not occur at this time: " ++ t.nice g))
transTau (con@TCon{}) = transTApp [con]
transTau (t@TApp _ _) = transTApp t.flat
transTau (TSig s) = do
    sig <- transSigma s
    -- a sigma that is really a tau is wrapped up anew with 'tauAsSigma'
    return (maybe sig tauAsSigma (sigmaAsTau sig))


{--
    Translate a type application that is given as a flat list,
    i.e. the applied type followed by its arguments.

    If the head is a type constructor, its name is resolved and looked up:

    - for a type symbol, the translated arguments are applied to the constructor

    - for an alias of a plain tau, the alias variables are substituted by the
      first arguments and the remaining arguments are applied to the result;
      too few arguments are an error

    - for any other alias, the number of arguments must be exactly the number
      of alias variables, and the substituted alias type is the result

    - anything else is an error

    In all error cases, and when the name could not be resolved, the result
    is the @()@ type from @pPreludeBase@.

    If the head is not a type constructor, it is translated and must pass
    'forceTau' before the arguments are applied.
 -}
transTApp (con:as) = do
     case con of
        TCon {pos,name} -> do
            tname <- resolveTName pos name
            g <- getST
            case tname of
                Nothing -> unit
                -- check if this is really a type constructor
                Just tn -> case g.findit tn of
                    Just (SymT {name,typ=ForAll bs _}) -> do
                            let ncon = TCon {pos, name}
                            as <- mapSt transTau as
                            appTauSigmas ncon as
                            -- checkTApp partial (ncon:as)
                    Just (alias@SymA {typ})
                        | ForAll _ (RhoTau _ tau) <- typ,
                          length as >= length alias.vars = do
                                as <- mapSt transTau as  -- partial args allowed in alias
                                tas <- mapSt forceTau as
                                -- zip pairs each alias variable with one argument,
                                -- surplus arguments are applied afterwards
                                let env = TM.fromList (zip (map Tau.var alias.vars) tas)
                                    ras = drop (length alias.vars) as
                                let nt = substTau env tau
                                -- make sure errors are flagged on the right position
                                let posnt = case Tau.flat nt of
                                        (tcon:xs) | tcon.{pos?} = Tau.mkapp tcon.{pos} xs
                                        other -> nt
                                appTauSigmas posnt ras
                        | ForAll _ (RhoTau _ tau) <- typ,
                          length as < length alias.vars = do
                                E.error pos (msgdoc ("apply " ++ alias.nice g ++ " to at least "
                                        ++ show (length alias.vars)
                                        ++ " type arguments"))
                                unit
                        {-
                            When we have a function alias, the number of type arguments
                            must match exactly, because the result will be a sigma,
                            and we can't apply remaining arguments to this.
                        -}
                        | length as == length alias.vars = do
                            sargs <- mapSt transTau as
                            targs <- mapSt forceTau sargs
                            let env = TM.fromList (zip (map Tau.var alias.vars) targs)
                            return (substSigma env alias.typ)
                        | otherwise = do E.error pos (msgdoc("Apply " ++ alias.nice g
                                            ++ " to exactly " ++ show alias.vars.length
                                            ++ " type arguments."))
                                         unit
                    Just sym -> do
                        E.error pos (msgdoc("`" ++ tn.nice g ++ "` is not a type constructor."))
                        unit
                    Nothing -> do
                        E.error pos (msgdoc("Can't find `" ++ tn.nice g ++ "`"))
                        unit
          -- the substitute result used whenever the type could not be translated
          where unit = return (tauAsSigma (TCon {pos, name=TName pPreludeBase "()"}))
        other -> do
            con <- transTau con >>= forceTau
            as  <- mapSt transTau as
            appTauSigmas con as
transTApp _ = undefined                   -- must only be used with tau.flat


--- Wrap a 'Tau' in a 'Sigma' that has neither bound variables nor contexts.
--- A 'TSig' is not wrapped, the sigma it contains is the result.
tauAsSigma tau = case tau of
    TSig s -> s
    _      -> ForAll [] (unTau (RhoTau [] tau))


{--
    Extract a 'Tau' from a 'Sigma' of the form:

    > ForAll [] (RhoTau [] t)
    > ForAll [] (RhoFun [] a b)

    A function rho qualifies only if @tauRho@ turns it into a 'RhoTau'
    without contexts. The result is 'Nothing' for every other sigma.
-}
sigmaAsTau (ForAll [] rho) = case rho of
    RhoTau [] t -> Just t
    RhoFun{}    -> case tauRho rho of
        RhoTau [] t -> Just t
        _           -> Nothing
    _           -> Nothing
sigmaAsTau _ = Nothing


--- Get the 'Tau' out of a 'Sigma' that should be a plain type.
--- If 'sigmaAsTau' finds none, an error is flagged and the @()@ type
--- from @pPreludeBase@ is returned in its place.
forceTau sig = case sigmaAsTau sig of
    Just tau -> return tau
    _        -> do
        g <- getST
        let here = sig.getpos
        E.error here (msgdoc ("illegal type  " ++ nicer sig g
            ++ ", forall types are not allowed here."))
        return TCon{pos=here, name=TName pPreludeBase "()"}


{--
    > appTauSigmas tau sigs

    Apply a type constructor to a list of 'Sigma's, from left to right,
    and wrap the resulting type with 'tauAsSigma'.

    This will result in an error whenever one of the 'Sigma's is not really a 'Tau';
    such an argument is left out of the application.
-}
appTauSigmas tau sigs = do
        applied <- foldM appTauSig tau sigs
        return (tauAsSigma applied)
    where
        appTauSig fun sig = case sigmaAsTau sig of
            Just arg -> return (TApp fun arg)
            Nothing  -> do
                g <- getST
                E.error (getpos fun) (msgdoc ("illegal type argument  " ++ nicer sig g
                    ++ " in type application of  " ++ nicer fun g
                    ++ ", forall types are not allowed here."))
                return fun


{--
    Print a warning or hint regarding a symbol, unless the documentation text
    of the symbol starts, after optional white space, with "nowarn:"

    The first argument does the actual reporting, it is passed the position
    of the symbol and the message.
    -}
symWarning :: (Position -> DOCUMENT -> StG ()) -> Symbol -> DOCUMENT -> StG ()
symWarning warn sym msg = case sym.doc of
    Just ´^\s*nowarn:´ -> return ()
    _                  -> warn sym.pos msg


{- ################# functions introduced through Classes.fr ############## -}
{-- @isSuper x g y@ is true if /x/ is the same as /y/, or if @isSuper x g s@ is true
    for any direct superclass /s/ of /y/ in environment /g/.

    For example @isSuper Eq g Ord@ will be true.

    The result is false if /y/ differs from /x/ and is not a class in /g/.
 -}
isSuper x g y
    | x == y    = true
    | otherwise = case Global.findit g y of
        Just (SymC {supers}) -> any (isSuper x g) supers
        _                    -> false


{- ################## functions introduced through Transform.fr ############ -}
{--
    Fold an 'Expr'.

    @foldEx b f a ex@ passes the accumulator _a_ and the expression _ex_ to the
    visitor _f_, which answers in the 'StG' monad with either ('Left' _na_) or
    ('Right' _na_), where _na_ is the updated accumulator.

    A 'Right' value ends the traversal of _ex_: the visitor alone is responsible
    for whatever recursion into the subexpressions is needed, 'foldEx' will not
    look at them. With a 'Left' value, 'foldEx' carries on with the
    subexpressions of _ex_, so that _f_ eventually sees each of them
    (and _ex_ itself, of course).

    If _b_ is true, 'foldEx' will also traverse the right hand sides of
    *@let@* definitions, before the body of the *@let@* is visited.

    Example: count the number of constructors in an expression, but do
    not include those on the right hand side of let-definitions

    > conCount a (Con {pos}) = stio (Right (a+1))
    > conCount a _           = stio (Left a)
    > ncons ex = foldEx false conCount 0 ex

    Example: find out maximal nesting depth

    > nestDepth a _ = do
    >          g <- getST
    >          stio (Left (max a (length g.env)))
    > nestEx ex = do
    >          a <- foldEx true nestDepth 0 ex
    >          g <- getST
    >          stio (a - length g.env)
 -}
foldEx :: Bool -> (a -> Expr -> StG (Either a a)) -> a -> Expr -> StG a
foldEx b f a ex = do
        visited <- f a ex
        case visited of
            Right done -> stio done
            Left  acc  -> descend acc
    where
        -- continue with the subexpressions of ex
        descend acc = case ex of
            Vbl{} -> stio acc
            Con{} -> stio acc
            Lit{} -> stio acc
            App fun arg _ -> do
                acc1 <- foldEx b f acc fun
                foldEx b f acc1 arg
            Let {env, ex=body} -> if b
                then do
                    -- visit the right hand sides of the definitions first
                    syms <- mapSt findV env
                    rhss <- sequence [ x | SymV {expr=Just x} <- syms ]
                    acc1 <- foldSt (foldEx b f) acc rhss
                    foldEx b f acc1 body
                else foldEx b f acc body
            Lam {ex=body} -> foldEx b f acc body
            Ifte {cnd,thn,els} -> do
                acc1 <- foldEx b f acc cnd
                acc2 <- foldEx b f acc1 thn
                foldEx b f acc2 els
            Mem {ex=sub} -> foldEx b f acc sub
            Case {ex=scrut, alts} -> do
                acc1 <- foldEx b f acc scrut
                foldSt (\r alt -> foldEx b f r (CAlt.ex alt)) acc1 alts
            Ann {ex=sub} -> foldEx b f acc sub

{--
    @mapEx b f ex@ applies function _f_ to 'Expr' _ex_ and, depending on the result,
    to all subexpressions of _ex_, and reconstructs the expression from the results.

    If _b_ is true, 'mapEx' descends also into the right hand sides of let
    definitions; the symbols of the definitions are updated with the results.

    _f_ is a function that returns ('Left' _x_) or ('Right' _x_) in the 'StG' monad.
    In the latter case, expression _x_ is the result of 'mapEx'.
    When _f_ returns a 'Left' value, however, 'mapEx' continues recursively into
    the subexpressions of the returned expression.
 -}
mapEx :: Bool -> (Expr -> StG (Either Expr Expr)) -> Expr -> StG Expr
mapEx b f x = do
    visited <- f x
    case visited of
        Right done -> stio done
        Left  fx   -> case fx of
            Vbl{} -> stio fx
            Con{} -> stio fx
            Lit{} -> stio fx
            App fun arg typ -> do
                nfun <- mapEx b f fun
                narg <- mapEx b f arg
                stio (App nfun narg typ)
            Let {env,ex,typ} -> if b
                then do
                    -- map the right hand sides of the definitions first
                    syms <- mapSt findV env
                    foreach [ sy | sy@SymV {expr=Just _} <- syms ] mapsub
                    body <- mapEx b f ex
                    stio (Let {env, ex=body, typ})
                else do
                    body <- mapEx b f ex
                    stio (Let {env, ex=body, typ})
              where
                mapsub (sy@SymV {expr=Just dx}) = do
                    rhs  <- dx
                    nrhs <- mapEx b f rhs
                    changeSym sy.{expr=Just (return nrhs)}
                mapsub sy = do
                    g <- getST
                    E.fatal sy.pos (text ("mapEx: strange symbol in let def rhs: "
                                ++ sy.nice g))
            Lam {pat,ex,typ} -> do
                body <- mapEx b f ex
                stio (Lam {pat, ex=body, typ})
            Ifte {cnd,thn,els,typ} -> do
                ncnd <- mapEx b f cnd
                nthn <- mapEx b f thn
                nels <- mapEx b f els
                stio (Ifte ncnd nthn nels typ)
            Mem {ex,member,typ} -> do
                sub <- mapEx b f ex
                stio (Mem {ex=sub, member, typ})
            Case {ckind,ex,alts,typ} -> do
                scrut <- mapEx b f ex
                nalts <- mapSt mapAlt alts
                stio (Case {ckind, ex=scrut, alts=nalts, typ})
              where
                mapAlt (CAlt {pat,ex}) = do
                    rhs <- mapEx b f ex
                    stio (CAlt {pat, ex=rhs})
            Ann {ex,typ} -> do
                sub <- mapEx b f ex
                stio (Ann {ex=sub, typ})


--- A variant of 'mapEx' that passes all outer lambdas and invokes 'mapEx' on the body.
--- The lambdas themselves are not shown to the visitor function, they are rebuilt
--- around the mapped body.
mapExBody b f x = case x of
    Lam {pat,ex,typ} -> do
        body <- mapExBody b f ex
        stio (Lam {pat, ex=body, typ})
    _ -> mapEx b f x


--- @localSyms ex@ returns the set of all local uids mentioned in _ex_,
--- the right hand sides of let definitions included.
localSyms ex = foldEx true collect TreeMap.empty ex
    where
        collect acc x = case x of
            Vbl {name=Local {uid}} -> stio (Left (including acc uid))
            _                      -> stio (Left acc)


--- Replace a set of uids: every local variable in the expression whose uid is a
--- key in the tree gets the uid the tree maps it to. Other local variables are
--- left as they are.
replUids :: TreeMap Int Int -> Expr -> StG Expr
replUids t = mapEx true rep
    where
        rep (v@Vbl{name=Local{uid}}) = stio (Right renamed)
            where renamed = maybe v (\nid -> v.{name <- QName.{uid=nid}}) (t.lookupI uid)
        rep other = stio (Left other)


{--
    Copy an expression, all patterns will be refreshed and references
    renamed accordingly. The new expression will have the same type as the old one.

    The first argument is an optional position. If it is given, the copied
    patterns, variables and symbols are placed there, and other expressions that
    have a position get its line, column and offset information.

    The second argument maps old uids to the uids that replace them.
    -}
copyExpr :: Maybe Position -> TreeMap Int Int -> Expr -> StG Expr    
copyExpr mbp t x = mapEx false (copy t) x where
    -- add one mapping from an old uid to a new uid to the tree
    ins t (k::Int,v) = TreeMap.insertI t k v

    -- copy a pattern together with the expression in its scope:
    -- the pattern gets new uids (and the position, if any), the tree is extended
    -- accordingly, and the symbols for the names of the old pattern are copied
    copyAlt tree CAlt{pat,ex} = do
        let qnames  = patNames pat
            olduids = map QName.uid qnames
        pat <- pReturn pat
        pat <- maybe (return pat) (flip pPos pat) mbp
        let newuids = map QName.uid (patNames pat)
            ntree   = fold ins tree (zip olduids newuids)
        ex <- copyExpr mbp ntree ex
        mapM_ (copySym mbp ntree) qnames
        stio CAlt{pat,ex}
    
    -- a lambda is copied like a case alternative
    copy tree (Lam{pat,ex,typ}) = do
        patex ← copyAlt tree CAlt{pat, ex}
        stio (Right Lam{pat=patex.pat,ex=patex.ex,typ})
    copy tree (Case ckind x alts typ) = do
        x    <- copyExpr mbp tree x
        alts <- mapSt (copyAlt tree) alts
        stio (Right (Case ckind x alts typ))
    -- let: the defined names get fresh uids, then symbols and body are copied
    copy tree (Let{env,ex,typ}) = do
        let olduids = map QName.uid env
        newuids <- sequence (map (const uniqid) olduids)
        let ntree = fold ins tree (zip olduids newuids)
        nenv <- mapSt (copySym mbp ntree) env
        ex   <- copyExpr mbp ntree ex
        stio (Right Let{env=nenv,ex,typ})
    -- a local variable is renamed if the tree knows its uid
    copy tree (v@Vbl{name=Local{uid}}) = case tree.lookupI uid of
            Just nid -> stio (Right v.{pos = npos, name <- QName.{uid=nid}})
            Nothing  -> stio (Right v.{pos = npos})
        where
            npos = (fromMaybe v.pos mbp).change VARID v.name.base
    copy _ (x@Ann{}) = stio (Left x)        -- keep annotations!
    -- anything else: take over line, column and offset of the given position, if any
    copy _ x
        | x.{pos?}, 
          Just pos ← mbp 
                    = pure (Left x.{pos = Pos{
                            first = x.pos.first.{line=pos.first.line, 
                                                 col=pos.first.col,
                                                 offset=pos.first.offset}, 
                            last = x.pos.last.{line=pos.last.line,
                                                col=pos.last.col,
                                                offset=pos.last.offset}}})

        | otherwise = pure (Left x)


{--
    Copy the local symbol named _qname_.

    The symbol is looked up with 'findV'. Its @sid@ must have an entry in _tree_,
    which supplies the new uid. The copy carries that uid both in its name and
    as its @sid@, a position taken from _mbp_ (or the old position if _mbp_ is
    'Nothing'), and, if the original had a defining expression, a copy of that
    expression made by 'copyExpr'. The copy is passed to 'enter' and its
    name is returned.

    Aborts with 'Prelude.error' if _tree_ has no entry for the @sid@.
-}
copySym mbp tree qname = do
    sym <- findV qname
    case tree.lookupI (Symbol.sid sym) of
        Nothing -> Prelude.error ("Can't find sid " ++ show sym.sid ++ " for name " ++ show sym.name)
        Just nuid -> do
            -- run the stored action (if any) and copy the expression it yields
            mbx <- maybeST sym.expr (\getx -> getx >>= copyExpr mbp tree)
            let name = sym.name.{uid=nuid}
                npos = (fromMaybe sym.pos mbp).change VARID name.base
            enter sym.{pos = npos, name, expr = fmap return mbx, sid = nuid}
            stio name


{--
    @maybeST mb act@ runs _act_ on the value inside a 'Just' and wraps the
    outcome in 'Just' again. For 'Nothing' it yields 'Nothing' without
    running _act_.
-}
maybeST mb act = case mb of
    Nothing -> stio Nothing
    Just v  -> act v >>= \r -> stio (Just r)


{--
    @untypeExpr x@ removes all type information from @x@

    Every symbol bound by a lambda pattern, a case alternative pattern or a let
    inside the expression is set back to state 'Unchecked', and its type is
    reset to 'pSigma' unless the symbol is annotated.
    Annotation nodes are passed on unchanged.
-}
untypeExpr x = mapEx true dropType x
    where
        -- reset state and, unless annotated, the type of one symbol
        resetSym qn = do
            sym <- findV qn
            let ntyp = if sym.anno then sym.typ else pSigma
            changeSym sym.{typ = ntyp, state = Unchecked}
        resetAll qns = foreach qns resetSym
        dropType (e@Ann{})      = return (Left e)    -- keep type signatures intact
        dropType (e@Lam{pat})   = resetAll (patNames pat) >> return (Left e.{typ=Nothing})
        dropType (e@Case{alts}) = resetAll [ q | CAlt{pat} <- alts, q <- patNames pat ] >> return (Left e.{typ=Nothing})
        dropType (e@Let{env})   = resetAll env >> return (Left e.{typ=Nothing})
        dropType e              = return (Left e.{typ=Nothing})


{-- @ourGlobalFuns mtree ex@ returns the set of all global symbols mentioned in /ex/
 *
 * _mtree_ is a 'Tree' 'String' ['Symbol'] mapping unqualified member names to
 * potential actual members. Whenever a member access like @x.foo@ is seen,
 * all symbols that _mtree_ lists under @foo@ are added to the result, so that
 * every @T.foo@ is type checked first.
 *
 * Local variables are ignored, and of the non-local variables only
 * those that resolve to a 'SymV' are collected.
 -}
ourGlobalFuns :: TreeMap String [Symbol] -> Expr -> StG (TreeMap Symbol ())
ourGlobalFuns mtree ex = foldEx true gather empty ex where
    gather found (Vbl {name=Local{}}) = stio (Left found)
    gather found (Vbl {name}) = do
        sym <- findVD name
        case sym of
            SymV{} -> stio (Left (including found sym))
            _      -> stio (Left found)
    gather found (Mem {member}) = case TreeMap.lookupS mtree member.value of
        Just syms -> stio (Left (fold including found syms))
        Nothing   -> stio (Left found)
    gather found _ = stio (Left found)


{--
    @symVD f g sym@ applies _f_ if _sym_ is a 'SymV' and _g_ if it is a 'SymD'.

    Any other kind of symbol aborts with 'Prelude.error'.
-}
symVD f _ (sym@SymV{}) = f sym
symVD _ g (sym@SymD{}) = g sym
symVD _ _ sym          = Prelude.error (sym.name.base ++ " is neither SymV nor SymD")


{--
 * [usage] @fundep sym@
 * [returns] a tuple of the name of _sym_ and the list of the names of those
 * symbols collected by 'ourGlobalFuns' from its defining expression
 * for which @g.our@ holds. The list is empty when _sym_ has no expression.
 *
 * A symbol that is not a 'SymV' is reported through @E.fatal@.
 -}
fundep (SymV {name, expr=Nothing}) = stio (name, [])
fundep (SymV {name, expr=Just dx}) = do
    g  <- getST
    ex <- dx
    mentioned <- ourGlobalFuns empty ex
    stio (name, [ qn | qn <- map Symbol.name (keys mentioned), g.our qn ])
fundep other = getST >>= \g -> E.fatal other.pos (text("fundep: strange symbol: " ++ other.nice g))


{--
    find all our 'SymV' symbols

    Searches @g.thisTab@ and the environment of every symbol in it
    that has an @env@ field, and keeps the 'SymV's whose name satisfies @g.our@.
-}
allourvars :: Global -> [Symbol]
allourvars g = [ v | env <- g.thisTab : nested, v@SymV {name} <- values env, g.our name]
    where
        -- environments of the symbols that have one
        nested = [ Symbol.env owner | owner <- values g.thisTab, Symbol.{env?} owner ]


{--
    find all 'SymV' symbols, be they ours or not

    For every symbol table in @g.packages@, the table itself and the
    environments of those of its symbols that have an @env@ field are searched,
    in that order. All 'SymV's found are returned.
-}
allvars = do
    g <- getST
    let envEnvs env = env : [Symbol.env sy | sy <- values env, Symbol.{env?} sy]
        -- A nested comprehension yields the same list as the left fold over
        -- '++' would, but does not re-copy the already collected prefix for
        -- every further package.
        collectedenvs = [ env | penv <- values g.packages, env <- envEnvs penv ]
        collectedvars = [ v | env::Symtab <- collectedenvs, v@SymV{} <- values env]
    stio collectedvars


{--
 * find a name of the same shape that is not used yet
 *
 * If _qn_ itself is not found in _g_, it is returned unchanged. Otherwise the
 * result is the first name, built by appending an element of @allBinders g@ to
 * the base of _qn_, that is not found in _g_.
 -}
unusedName qn g = case Global.find g qn of
    Nothing -> qn
    Just _  -> head [ cand | suffix <- allBinders g,
                                let cand = qn.qual (qn.base ++ suffix), isNothing (g.find cand)]

{--
    find a name of the same shape with a suffix of upper case ASCII letters that is not used yet

    If _qn_ itself is not found in _g_, it is returned unchanged. The suffixes
    tried are the upper cased elements of 'allAsciiBinders', in order.
-}
unusedASCIIName qn g = case Global.find g qn of
    Nothing -> qn
    Just _  -> head [ cand | suffix <- map _.toUpperCase allAsciiBinders,
                                let cand = qn.qual (qn.base ++ suffix), isNothing (g.find cand)]

{--
 * convenience 'StIO' printing functions
 *
 * 'print' renders its argument with 'display' and hands the result to
 * the @print@ method of the @printer@ found in the current state.
 -}
print x = getSTT >>= \g -> liftIO (g.printer.print (display x)) >> return ()


{--
    Renders its argument with 'display' and hands the result to
    the @println@ method of the @printer@ found in the current state.
-}
println x = getSTT >>= \g -> liftIO (g.printer.println (display x)) >> return ()


{--
 * [usage] @returnType rho@
 * [returns] a tuple of the 'Tau' type found in the 'RhoTau' that ends a chain
 * of 'RhoFun's, and the list of the argument types (which are 'Sigma's)
 * collected from that chain, outermost first.
 *
 * By definition, the length of the list of argument types is the arity of the function,
 * yet the actual number of formal arguments derived from the
 * number of nested lambdas (given in 'Symbol.depth') may be smaller and even 0.
 -}
returnType rho = case rho of
    RhoTau _ tau       -> (tau, [])
    RhoFun _ sig rest  -> let (tau, sigs) = returnType rest in (tau, sig:sigs)


{--
 * [usage] @returnTypeN n rho@
 * [returns] a tuple of the type that remains after applying @n@ arguments to a function
 * of the type given in @rho@ (passed through 'tauRho'), and the @n@ 'Sigma'
 * types of those arguments.
 * [undefined] if @n@ is greater than the 'arity' of the type. This should never happen after
 * type check and hints at a compiler error. It aborts with 'Prelude.error'.
 -}
returnTypeN 0 !rho = (tauRho rho, [])
returnTypeN n rho  = case rho of
    RhoFun _ sig rest -> let (res, sigs) = returnTypeN (n-1) rest in (res, sig:sigs)
    _                 -> Prelude.error "returnTypeN: too many arguments"




{--
    tell if a given type is a java type

    - a type constructor is one if the symbol 'findT' yields for it has a @nativ@ entry
    - a type application is decided by the first element of its @flat@ list
    - a flexible meta variable is decided by the type @g.bound@ yields for it,
      and is not a java type if there is none
    - anything else is not a java type
-}
isJavaType tau = case tau of
    TCon {name}          -> findT name >>= \sym -> stio (isJust sym.nativ)
    (app@TApp _ _)       -> isJavaType (head app.flat)
    Meta tv | tv.isFlexi -> getST >>= \g -> maybe (stio false) isJavaType (g.bound tv)
    _                    -> stio false


{-- Arity of a 'Symbol' based on its type: the number of argument types 'returnType' finds in the @rho@ of the symbol's type -}
arity sym = length argtypes
    where (_, argtypes) = returnType (Symbol.typ sym).rho


--- If the type is an application of the @[]@ type constructor from 'pPreludeBase', return 'Just' the element type, otherwise 'Nothing'.
isList tau = case tau of
    TApp (TCon {name = TName pack "[]"}) elemty | pack == pPreludeBase -> Just elemty
    _ -> Nothing


--- If the type is the @()@ type constructor from 'pPreludeBase', return it wrapped in 'Just', otherwise 'Nothing'.
isUnit tau = case tau of
    TCon {name = TName pack "()"} | pack == pPreludeBase -> Just tau
    _ -> Nothing


--- If the type is an application of the @Maybe@ type constructor from 'pPreludeBase', return 'Just' the argument type, otherwise 'Nothing'.
isMaybe tau = case tau of
    TApp (TCon {name = TName pack "Maybe"}) argty | pack == pPreludeBase -> Just argty
    _ -> Nothing


{-- 
    @isException g tau@ returns 'Just' (_left_, _right_) if and only if the following
    conditions both hold:
    - _tau_ is of the form 'Either' _left_ _right_, where 'Either' is the one
      from 'pPreludeBase'
    - _left_ is acceptable, that is, either ('isThrowable' g _left_) holds, or
      _left_ is itself of the form 'Either' _a_ _b_ where _a_ is acceptable
      and ('isThrowable' g _b_) holds

    In all other cases the result is 'Nothing'.
-} 
isException g tau =
    case tau of
        TApp (TApp con left) right | eitherCon con && throwing left -> Just (left, right)
        _ -> Nothing
    where
        -- is this the Either type constructor from PreludeBase?
        eitherCon (TCon {name = TName p1 "Either"}) = p1 == pPreludeBase
        eitherCon _                                 = false
        -- a nested Either stands for several alternative exception types
        throwing (TApp (TApp con a) b) = eitherCon con && throwing a && isThrowable g b
        throwing x = isThrowable g x


{-- 
    @true@ if and only if in environment _g_, 
    type _ty_ denotes a sub type of @java.lang.Throwable@ 

    This is the case when 'instTauSym' yields a 'SymT' with a @nativ@ name
    that either is @java.lang.Throwable@ or has it among the names
    returned by 'supersOfNativ'.
-}
isThrowable g ty
    | Just SymT{nativ=Just jtype} <- instTauSym ty g
                = jtype == throwable || throwable `elem` supersOfNativ jtype g
    | otherwise = false
    where
        throwable = "java.lang.Throwable"


{--
    > isClassMember sym g = Just SymC{}
    If @sym@ is a 'SymV' whose name is a member name, and the type name
    part of that name is found in _g_ to be a class, return that class,
    otherwise 'Nothing'. 
-}
isClassMember (SymV{name=MName{tynm}}) g =
    case Global.findit g tynm of
        Just (cls@SymC{}) -> Just cls
        _                 -> Nothing
isClassMember _ _ = Nothing


-- sigma manipulation stuff
